TP SP examples improvement #1354
Conversation
output.sum().backward()
optimizer.step()
inp = torch.rand(4, 10, device=device_type)
comm_mode = CommDebugMode()
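For context, a minimal sketch of how CommDebugMode is typically wrapped around a training step (assuming the example's tp_model, optimizer, and device_type are already set up; the import path can vary across PyTorch versions):

```python
import torch
from torch.distributed.tensor.debug import CommDebugMode

# Assumes tp_model, optimizer, and device_type are set up as in the example.
inp = torch.rand(4, 10, device=device_type)
comm_mode = CommDebugMode()
with comm_mode:
    output = tp_model(inp)
    output.sum().backward()
optimizer.step()

# The counters and sharding info in the logs below come from calls like these:
print(comm_mode.get_comm_counts())
print(comm_mode.get_sharding_info())
print(comm_mode.generate_comm_debug_tracing_table(noise_level=1))
```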
Does this work on non-CUDA devices? It would be great to share some local logs of your tests.
Gladly. Please see the attached logs from an H100 run.
Starting PyTorch TP example on rank 3.
Starting PyTorch TP example on rank 0.
06/16/2025 05:55:00 PM Device Mesh created: device_mesh=DeviceMesh('cuda', [0, 1, 2, 3])
Starting PyTorch TP example on rank 2.
Starting PyTorch TP example on rank 1.
model ToyModel(
(in_proj): Linear(in_features=10, out_features=32, bias=True)
(relu): ReLU()
(out_proj): Linear(in_features=32, out_features=5, bias=True)
)
06/16/2025 05:55:03 PM Tensor Parallel training starting...
06/16/2025 05:55:03 PM Tensor Parallel iter 0 completed
rank3 1 get_comm_counts defaultdict(<class 'int'>, {<OpOverloadPacket(op='c10d_functional.all_reduce')>: 1}) get_sharding_info() {'ToyModel.in_proj.weight': (Shard(dim=0),), 'ToyModel.in_proj.bias': (Shard(dim=0),), 'ToyModel.out_proj.weight': (Shard(dim=1),), 'ToyModel.out_proj.bias': (Replicate(),)} generate_comm_debug_tracing_table Global
FORWARD PASS
*c10d_functional.all_reduce: 1
BACKWARD PASS
ToyModel
*module type: class '__main__.ToyModel'
FORWARD PASS
*c10d_functional.all_reduce: 1
ToyModel.in_proj
*module type: class 'torch.nn.modules.linear.Linear'
*Parameter List
*weight: (Shard(dim=0),)
*bias: (Shard(dim=0),)
FORWARD PASS
**aten.addmm.default
shape: [torch.Size([32]), torch.Size([4, 10]), torch.Size([10, 32])]
sharding: [(Shard(dim=0),), (Replicate(),), (Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
BACKWARD PASS
**aten.mm.default
shape: [torch.Size([32, 4]), torch.Size([4, 10])]
sharding: [(Shard(dim=0),), (Replicate(),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.sum.dim_IntList
shape: [torch.Size([4, 32])]
sharding: [(Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.add_.Tensor
shape: [torch.Size([32]), torch.Size([32])]
sharding: [(Shard(dim=0),), (Shard(dim=0),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.add_.Tensor
shape: [torch.Size([32, 10]), torch.Size([32, 10])]
sharding: [(Shard(dim=0),), (Shard(dim=0),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
ToyModel.relu
*module type: class 'torch.nn.modules.activation.ReLU'
FORWARD PASS
BACKWARD PASS
ToyModel.out_proj
*module type: class 'torch.nn.modules.linear.Linear'
*Parameter List
*weight: (Shard(dim=1),)
*bias: (Replicate(),)
FORWARD PASS
*c10d_functional.all_reduce: 1
**aten.addmm.default
shape: [torch.Size([5]), torch.Size([4, 32]), torch.Size([32, 5])]
sharding: [(Replicate(),), (Shard(dim=1),), (Shard(dim=0),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
BACKWARD PASS
**aten.mm.default
shape: [torch.Size([4, 5]), torch.Size([5, 32])]
sharding: [(Replicate(),), (Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.mm.default
shape: [torch.Size([5, 4]), torch.Size([4, 32])]
sharding: [(Replicate(),), (Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.sum.dim_IntList
shape: [torch.Size([4, 5])]
sharding: [(Replicate(),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.add_.Tensor
shape: [torch.Size([5]), torch.Size([5])]
sharding: [(Replicate(),), (Replicate(),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.add_.Tensor
shape: [torch.Size([5, 32]), torch.Size([5, 32])]
sharding: [(Shard(dim=1),), (Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
06/16/2025 05:55:03 PM Tensor Parallel iter 1 completed
rank0 1 get_comm_counts defaultdict(<class 'int'>, {<OpOverloadPacket(op='c10d_functional.all_reduce')>: 1}) get_sharding_info() {'ToyModel.in_proj.weight': (Shard(dim=0),), 'ToyModel.in_proj.bias': (Shard(dim=0),), 'ToyModel.out_proj.weight': (Shard(dim=1),), 'ToyModel.out_proj.bias': (Replicate(),)} generate_comm_debug_tracing_table Global
FORWARD PASS
*c10d_functional.all_reduce: 1
BACKWARD PASS
ToyModel
*module type: class '__main__.ToyModel'
FORWARD PASS
*c10d_functional.all_reduce: 1
ToyModel.in_proj
*module type: class 'torch.nn.modules.linear.Linear'
*Parameter List
*weight: (Shard(dim=0),)
*bias: (Shard(dim=0),)
FORWARD PASS
**aten.addmm.default
shape: [torch.Size([32]), torch.Size([4, 10]), torch.Size([10, 32])]
sharding: [(Shard(dim=0),), (Replicate(),), (Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
BACKWARD PASS
**aten.mm.default
shape: [torch.Size([32, 4]), torch.Size([4, 10])]
sharding: [(Shard(dim=0),), (Replicate(),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.sum.dim_IntList
shape: [torch.Size([4, 32])]
sharding: [(Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.add_.Tensor
shape: [torch.Size([32]), torch.Size([32])]
sharding: [(Shard(dim=0),), (Shard(dim=0),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.add_.Tensor
shape: [torch.Size([32, 10]), torch.Size([32, 10])]
sharding: [(Shard(dim=0),), (Shard(dim=0),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
ToyModel.relu
*module type: class 'torch.nn.modules.activation.ReLU'
FORWARD PASS
BACKWARD PASS
ToyModel.out_proj
*module type: class 'torch.nn.modules.linear.Linear'
*Parameter List
*weight: (Shard(dim=1),)
*bias: (Replicate(),)
FORWARD PASS
*c10d_functional.all_reduce: 1
**aten.addmm.default
shape: [torch.Size([5]), torch.Size([4, 32]), torch.Size([32, 5])]
sharding: [(Replicate(),), (Shard(dim=1),), (Shard(dim=0),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
BACKWARD PASS
**aten.mm.default
shape: [torch.Size([4, 5]), torch.Size([5, 32])]
sharding: [(Replicate(),), (Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.mm.default
shape: [torch.Size([5, 4]), torch.Size([4, 32])]
sharding: [(Replicate(),), (Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.sum.dim_IntList
shape: [torch.Size([4, 5])]
sharding: [(Replicate(),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.add_.Tensor
shape: [torch.Size([5]), torch.Size([5])]
sharding: [(Replicate(),), (Replicate(),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.add_.Tensor
shape: [torch.Size([5, 32]), torch.Size([5, 32])]
sharding: [(Shard(dim=1),), (Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
rank2 1 get_comm_counts defaultdict(<class 'int'>, {<OpOverloadPacket(op='c10d_functional.all_reduce')>: 1}) get_sharding_info() {'ToyModel.in_proj.weight': (Shard(dim=0),), 'ToyModel.in_proj.bias': (Shard(dim=0),), 'ToyModel.out_proj.weight': (Shard(dim=1),), 'ToyModel.out_proj.bias': (Replicate(),)} generate_comm_debug_tracing_table Global
FORWARD PASS
*c10d_functional.all_reduce: 1
BACKWARD PASS
ToyModel
*module type: class '__main__.ToyModel'
FORWARD PASS
*c10d_functional.all_reduce: 1
ToyModel.in_proj
*module type: class 'torch.nn.modules.linear.Linear'
*Parameter List
*weight: (Shard(dim=0),)
*bias: (Shard(dim=0),)
FORWARD PASS
**aten.addmm.default
shape: [torch.Size([32]), torch.Size([4, 10]), torch.Size([10, 32])]
sharding: [(Shard(dim=0),), (Replicate(),), (Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
BACKWARD PASS
**aten.mm.default
shape: [torch.Size([32, 4]), torch.Size([4, 10])]
sharding: [(Shard(dim=0),), (Replicate(),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.sum.dim_IntList
shape: [torch.Size([4, 32])]
sharding: [(Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.add_.Tensor
shape: [torch.Size([32]), torch.Size([32])]
sharding: [(Shard(dim=0),), (Shard(dim=0),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.add_.Tensor
shape: [torch.Size([32, 10]), torch.Size([32, 10])]
sharding: [(Shard(dim=0),), (Shard(dim=0),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
ToyModel.relu
*module type: class 'torch.nn.modules.activation.ReLU'
FORWARD PASS
BACKWARD PASS
ToyModel.out_proj
*module type: class 'torch.nn.modules.linear.Linear'
*Parameter List
*weight: (Shard(dim=1),)
*bias: (Replicate(),)
FORWARD PASS
*c10d_functional.all_reduce: 1
**aten.addmm.default
shape: [torch.Size([5]), torch.Size([4, 32]), torch.Size([32, 5])]
sharding: [(Replicate(),), (Shard(dim=1),), (Shard(dim=0),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
BACKWARD PASS
**aten.mm.default
shape: [torch.Size([4, 5]), torch.Size([5, 32])]
sharding: [(Replicate(),), (Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.mm.default
shape: [torch.Size([5, 4]), torch.Size([4, 32])]
sharding: [(Replicate(),), (Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.sum.dim_IntList
shape: [torch.Size([4, 5])]
sharding: [(Replicate(),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.add_.Tensor
shape: [torch.Size([5]), torch.Size([5])]
sharding: [(Replicate(),), (Replicate(),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.add_.Tensor
shape: [torch.Size([5, 32]), torch.Size([5, 32])]
sharding: [(Shard(dim=1),), (Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
rank1 1 get_comm_counts defaultdict(<class 'int'>, {<OpOverloadPacket(op='c10d_functional.all_reduce')>: 1}) get_sharding_info() {'ToyModel.in_proj.weight': (Shard(dim=0),), 'ToyModel.in_proj.bias': (Shard(dim=0),), 'ToyModel.out_proj.weight': (Shard(dim=1),), 'ToyModel.out_proj.bias': (Replicate(),)} generate_comm_debug_tracing_table Global
FORWARD PASS
*c10d_functional.all_reduce: 1
BACKWARD PASS
ToyModel
*module type: class '__main__.ToyModel'
FORWARD PASS
*c10d_functional.all_reduce: 1
ToyModel.in_proj
*module type: class 'torch.nn.modules.linear.Linear'
*Parameter List
*weight: (Shard(dim=0),)
*bias: (Shard(dim=0),)
FORWARD PASS
**aten.addmm.default
shape: [torch.Size([32]), torch.Size([4, 10]), torch.Size([10, 32])]
sharding: [(Shard(dim=0),), (Replicate(),), (Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
BACKWARD PASS
**aten.mm.default
shape: [torch.Size([32, 4]), torch.Size([4, 10])]
sharding: [(Shard(dim=0),), (Replicate(),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.sum.dim_IntList
shape: [torch.Size([4, 32])]
sharding: [(Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.add_.Tensor
shape: [torch.Size([32]), torch.Size([32])]
sharding: [(Shard(dim=0),), (Shard(dim=0),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.add_.Tensor
shape: [torch.Size([32, 10]), torch.Size([32, 10])]
sharding: [(Shard(dim=0),), (Shard(dim=0),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
ToyModel.relu
*module type: class 'torch.nn.modules.activation.ReLU'
FORWARD PASS
BACKWARD PASS
ToyModel.out_proj
*module type: class 'torch.nn.modules.linear.Linear'
*Parameter List
*weight: (Shard(dim=1),)
*bias: (Replicate(),)
FORWARD PASS
*c10d_functional.all_reduce: 1
**aten.addmm.default
shape: [torch.Size([5]), torch.Size([4, 32]), torch.Size([32, 5])]
sharding: [(Replicate(),), (Shard(dim=1),), (Shard(dim=0),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
BACKWARD PASS
**aten.mm.default
shape: [torch.Size([4, 5]), torch.Size([5, 32])]
sharding: [(Replicate(),), (Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.mm.default
shape: [torch.Size([5, 4]), torch.Size([4, 32])]
sharding: [(Replicate(),), (Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.sum.dim_IntList
shape: [torch.Size([4, 5])]
sharding: [(Replicate(),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.add_.Tensor
shape: [torch.Size([5]), torch.Size([5])]
sharding: [(Replicate(),), (Replicate(),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.add_.Tensor
shape: [torch.Size([5, 32]), torch.Size([5, 32])]
sharding: [(Shard(dim=1),), (Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
06/16/2025 05:55:03 PM Tensor Parallel iter 2 completed
06/16/2025 05:55:03 PM Tensor Parallel iter 3 completed
06/16/2025 05:55:03 PM Tensor Parallel iter 4 completed
06/16/2025 05:55:03 PM Tensor Parallel iter 5 completed
06/16/2025 05:55:03 PM Tensor Parallel iter 6 completed
06/16/2025 05:55:04 PM Tensor Parallel iter 7 completed
06/16/2025 05:55:04 PM Tensor Parallel iter 8 completed
06/16/2025 05:55:04 PM Tensor Parallel iter 9 completed
06/16/2025 05:55:04 PM Tensor Parallel training completed!
[rank0]:[W616 17:55:04.791527408 ProcessGroupNCCL.cpp:1516] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
Starting PyTorch Sequence Parallel example on rank 0.
06/16/2025 05:53:21 PM Device Mesh created: device_mesh=DeviceMesh('cuda', [0, 1, 2, 3])
Starting PyTorch Sequence Parallel example on rank 3.
Starting PyTorch Sequence Parallel example on rank 2.
Starting PyTorch Sequence Parallel example on rank 1.
model ToyModel(
(in_proj): Linear(in_features=10, out_features=32, bias=True)
(relu): ReLU()
(out_proj): Linear(in_features=32, out_features=5, bias=True)
)
06/16/2025 05:53:24 PM Sequence Parallel training starting...
rank2 0 get_comm_counts defaultdict(<class 'int'>, {<OpOverloadPacket(op='c10d_functional.all_gather_into_tensor')>: 2, <OpOverloadPacket(op='c10d_functional.reduce_scatter_tensor')>: 1}) get_sharding_info() {'ToyModel.in_proj.weight': (Shard(dim=0),), 'ToyModel.in_proj.bias': (Shard(dim=0),), 'ToyModel.out_proj.weight': (Shard(dim=1),), 'ToyModel.out_proj.bias': (Replicate(),)} generate_comm_debug_tracing_table Global
FORWARD PASS
*c10d_functional.all_gather_into_tensor: 1
*c10d_functional.reduce_scatter_tensor: 1
BACKWARD PASS
*c10d_functional.all_gather_into_tensor: 1
ToyModel
*module type: class '__main__.ToyModel'
FORWARD PASS
*c10d_functional.all_gather_into_tensor: 1
*c10d_functional.reduce_scatter_tensor: 1
BACKWARD PASS
*c10d_functional.all_gather_into_tensor: 1
ToyModel.in_proj
*module type: class 'torch.nn.modules.linear.Linear'
*Parameter List
*weight: (Shard(dim=0),)
*bias: (Shard(dim=0),)
FORWARD PASS
*c10d_functional.all_gather_into_tensor: 1
**aten.addmm.default
shape: [torch.Size([32]), torch.Size([4, 10]), torch.Size([10, 32])]
sharding: [(Shard(dim=0),), (Replicate(),), (Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.zeros_like.default
shape: [torch.Size([32, 10])]
sharding: [(Shard(dim=0),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.zeros_like.default
shape: [torch.Size([32, 10])]
sharding: [(Shard(dim=0),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.zeros_like.default
shape: [torch.Size([32])]
sharding: [(Shard(dim=0),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.zeros_like.default
shape: [torch.Size([32])]
sharding: [(Shard(dim=0),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.zeros_like.default
shape: [torch.Size([5, 32])]
sharding: [(Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.zeros_like.default
shape: [torch.Size([5, 32])]
sharding: [(Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.zeros_like.default
shape: [torch.Size([5])]
sharding: [(Replicate(),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.zeros_like.default
shape: [torch.Size([5])]
sharding: [(Replicate(),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
BACKWARD PASS
**aten.mm.default
shape: [torch.Size([32, 4]), torch.Size([4, 10])]
sharding: [(Shard(dim=0),), (Replicate(),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.sum.dim_IntList
shape: [torch.Size([4, 32])]
sharding: [(Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
ToyModel.relu
*module type: class 'torch.nn.modules.activation.ReLU'
FORWARD PASS
BACKWARD PASS
ToyModel.out_proj
*module type: class 'torch.nn.modules.linear.Linear'
*Parameter List
*weight: (Shard(dim=1),)
*bias: (Replicate(),)
FORWARD PASS
*c10d_functional.reduce_scatter_tensor: 1
**aten.addmm.default
shape: [torch.Size([5]), torch.Size([4, 32]), torch.Size([32, 5])]
sharding: [(Replicate(),), (Shard(dim=1),), (Shard(dim=0),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
BACKWARD PASS
*c10d_functional.all_gather_into_tensor: 1
**aten.mm.default
shape: [torch.Size([4, 5]), torch.Size([5, 32])]
sharding: [(Replicate(),), (Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.mm.default
shape: [torch.Size([5, 4]), torch.Size([4, 32])]
sharding: [(Replicate(),), (Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.sum.dim_IntList
shape: [torch.Size([4, 5])]
sharding: [(Replicate(),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
06/16/2025 05:53:25 PM Sequence Parallel iter 0 completed
rank0 0 get_comm_counts defaultdict(<class 'int'>, {<OpOverloadPacket(op='c10d_functional.all_gather_into_tensor')>: 2, <OpOverloadPacket(op='c10d_functional.reduce_scatter_tensor')>: 1}) get_sharding_info() {'ToyModel.in_proj.weight': (Shard(dim=0),), 'ToyModel.in_proj.bias': (Shard(dim=0),), 'ToyModel.out_proj.weight': (Shard(dim=1),), 'ToyModel.out_proj.bias': (Replicate(),)} generate_comm_debug_tracing_table Global
FORWARD PASS
*c10d_functional.all_gather_into_tensor: 1
*c10d_functional.reduce_scatter_tensor: 1
BACKWARD PASS
*c10d_functional.all_gather_into_tensor: 1
ToyModel
*module type: class '__main__.ToyModel'
FORWARD PASS
*c10d_functional.all_gather_into_tensor: 1
*c10d_functional.reduce_scatter_tensor: 1
BACKWARD PASS
*c10d_functional.all_gather_into_tensor: 1
ToyModel.in_proj
*module type: class 'torch.nn.modules.linear.Linear'
*Parameter List
*weight: (Shard(dim=0),)
*bias: (Shard(dim=0),)
FORWARD PASS
*c10d_functional.all_gather_into_tensor: 1
**aten.addmm.default
shape: [torch.Size([32]), torch.Size([4, 10]), torch.Size([10, 32])]
sharding: [(Shard(dim=0),), (Replicate(),), (Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.zeros_like.default
shape: [torch.Size([32, 10])]
sharding: [(Shard(dim=0),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.zeros_like.default
shape: [torch.Size([32, 10])]
sharding: [(Shard(dim=0),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.zeros_like.default
shape: [torch.Size([32])]
sharding: [(Shard(dim=0),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.zeros_like.default
shape: [torch.Size([32])]
sharding: [(Shard(dim=0),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.zeros_like.default
shape: [torch.Size([5, 32])]
sharding: [(Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.zeros_like.default
shape: [torch.Size([5, 32])]
sharding: [(Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.zeros_like.default
shape: [torch.Size([5])]
sharding: [(Replicate(),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.zeros_like.default
shape: [torch.Size([5])]
sharding: [(Replicate(),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
BACKWARD PASS
**aten.mm.default
shape: [torch.Size([32, 4]), torch.Size([4, 10])]
sharding: [(Shard(dim=0),), (Replicate(),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.sum.dim_IntList
shape: [torch.Size([4, 32])]
sharding: [(Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
ToyModel.relu
*module type: class 'torch.nn.modules.activation.ReLU'
FORWARD PASS
BACKWARD PASS
ToyModel.out_proj
*module type: class 'torch.nn.modules.linear.Linear'
*Parameter List
*weight: (Shard(dim=1),)
*bias: (Replicate(),)
FORWARD PASS
*c10d_functional.reduce_scatter_tensor: 1
**aten.addmm.default
shape: [torch.Size([5]), torch.Size([4, 32]), torch.Size([32, 5])]
sharding: [(Replicate(),), (Shard(dim=1),), (Shard(dim=0),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
BACKWARD PASS
*c10d_functional.all_gather_into_tensor: 1
**aten.mm.default
shape: [torch.Size([4, 5]), torch.Size([5, 32])]
sharding: [(Replicate(),), (Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.mm.default
shape: [torch.Size([5, 4]), torch.Size([4, 32])]
sharding: [(Replicate(),), (Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.sum.dim_IntList
shape: [torch.Size([4, 5])]
sharding: [(Replicate(),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
rank1 0 get_comm_counts defaultdict(<class 'int'>, {<OpOverloadPacket(op='c10d_functional.all_gather_into_tensor')>: 2, <OpOverloadPacket(op='c10d_functional.reduce_scatter_tensor')>: 1}) get_sharding_info() {'ToyModel.in_proj.weight': (Shard(dim=0),), 'ToyModel.in_proj.bias': (Shard(dim=0),), 'ToyModel.out_proj.weight': (Shard(dim=1),), 'ToyModel.out_proj.bias': (Replicate(),)} generate_comm_debug_tracing_table Global
FORWARD PASS
*c10d_functional.all_gather_into_tensor: 1
*c10d_functional.reduce_scatter_tensor: 1
BACKWARD PASS
*c10d_functional.all_gather_into_tensor: 1
ToyModel
*module type: class '__main__.ToyModel'
FORWARD PASS
*c10d_functional.all_gather_into_tensor: 1
*c10d_functional.reduce_scatter_tensor: 1
BACKWARD PASS
*c10d_functional.all_gather_into_tensor: 1
ToyModel.in_proj
*module type: class 'torch.nn.modules.linear.Linear'
*Parameter List
*weight: (Shard(dim=0),)
*bias: (Shard(dim=0),)
FORWARD PASS
*c10d_functional.all_gather_into_tensor: 1
**aten.addmm.default
shape: [torch.Size([32]), torch.Size([4, 10]), torch.Size([10, 32])]
sharding: [(Shard(dim=0),), (Replicate(),), (Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.zeros_like.default
shape: [torch.Size([32, 10])]
sharding: [(Shard(dim=0),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.zeros_like.default
shape: [torch.Size([32, 10])]
sharding: [(Shard(dim=0),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.zeros_like.default
shape: [torch.Size([32])]
sharding: [(Shard(dim=0),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.zeros_like.default
shape: [torch.Size([32])]
sharding: [(Shard(dim=0),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.zeros_like.default
shape: [torch.Size([5, 32])]
sharding: [(Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.zeros_like.default
shape: [torch.Size([5, 32])]
sharding: [(Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.zeros_like.default
shape: [torch.Size([5])]
sharding: [(Replicate(),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.zeros_like.default
shape: [torch.Size([5])]
sharding: [(Replicate(),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
BACKWARD PASS
**aten.mm.default
shape: [torch.Size([32, 4]), torch.Size([4, 10])]
sharding: [(Shard(dim=0),), (Replicate(),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.sum.dim_IntList
shape: [torch.Size([4, 32])]
sharding: [(Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
ToyModel.relu
*module type: class 'torch.nn.modules.activation.ReLU'
FORWARD PASS
BACKWARD PASS
ToyModel.out_proj
*module type: class 'torch.nn.modules.linear.Linear'
*Parameter List
*weight: (Shard(dim=1),)
*bias: (Replicate(),)
FORWARD PASS
*c10d_functional.reduce_scatter_tensor: 1
**aten.addmm.default
shape: [torch.Size([5]), torch.Size([4, 32]), torch.Size([32, 5])]
sharding: [(Replicate(),), (Shard(dim=1),), (Shard(dim=0),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
BACKWARD PASS
*c10d_functional.all_gather_into_tensor: 1
**aten.mm.default
shape: [torch.Size([4, 5]), torch.Size([5, 32])]
sharding: [(Replicate(),), (Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.mm.default
shape: [torch.Size([5, 4]), torch.Size([4, 32])]
sharding: [(Replicate(),), (Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.sum.dim_IntList
shape: [torch.Size([4, 5])]
sharding: [(Replicate(),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
rank3 0 get_comm_counts defaultdict(<class 'int'>, {<OpOverloadPacket(op='c10d_functional.all_gather_into_tensor')>: 2, <OpOverloadPacket(op='c10d_functional.reduce_scatter_tensor')>: 1}) get_sharding_info() {'ToyModel.in_proj.weight': (Shard(dim=0),), 'ToyModel.in_proj.bias': (Shard(dim=0),), 'ToyModel.out_proj.weight': (Shard(dim=1),), 'ToyModel.out_proj.bias': (Replicate(),)} generate_comm_debug_tracing_table Global
FORWARD PASS
*c10d_functional.all_gather_into_tensor: 1
*c10d_functional.reduce_scatter_tensor: 1
BACKWARD PASS
*c10d_functional.all_gather_into_tensor: 1
ToyModel
*module type: class '__main__.ToyModel'
FORWARD PASS
*c10d_functional.all_gather_into_tensor: 1
*c10d_functional.reduce_scatter_tensor: 1
BACKWARD PASS
*c10d_functional.all_gather_into_tensor: 1
ToyModel.in_proj
*module type: class 'torch.nn.modules.linear.Linear'
*Parameter List
*weight: (Shard(dim=0),)
*bias: (Shard(dim=0),)
FORWARD PASS
*c10d_functional.all_gather_into_tensor: 1
**aten.addmm.default
shape: [torch.Size([32]), torch.Size([4, 10]), torch.Size([10, 32])]
sharding: [(Shard(dim=0),), (Replicate(),), (Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.zeros_like.default
shape: [torch.Size([32, 10])]
sharding: [(Shard(dim=0),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.zeros_like.default
shape: [torch.Size([32, 10])]
sharding: [(Shard(dim=0),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.zeros_like.default
shape: [torch.Size([32])]
sharding: [(Shard(dim=0),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.zeros_like.default
shape: [torch.Size([32])]
sharding: [(Shard(dim=0),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.zeros_like.default
shape: [torch.Size([5, 32])]
sharding: [(Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.zeros_like.default
shape: [torch.Size([5, 32])]
sharding: [(Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.zeros_like.default
shape: [torch.Size([5])]
sharding: [(Replicate(),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.zeros_like.default
shape: [torch.Size([5])]
sharding: [(Replicate(),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
BACKWARD PASS
**aten.mm.default
shape: [torch.Size([32, 4]), torch.Size([4, 10])]
sharding: [(Shard(dim=0),), (Replicate(),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.sum.dim_IntList
shape: [torch.Size([4, 32])]
sharding: [(Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
ToyModel.relu
*module type: class 'torch.nn.modules.activation.ReLU'
FORWARD PASS
BACKWARD PASS
ToyModel.out_proj
*module type: class 'torch.nn.modules.linear.Linear'
*Parameter List
*weight: (Shard(dim=1),)
*bias: (Replicate(),)
FORWARD PASS
*c10d_functional.reduce_scatter_tensor: 1
**aten.addmm.default
shape: [torch.Size([5]), torch.Size([4, 32]), torch.Size([32, 5])]
sharding: [(Replicate(),), (Shard(dim=1),), (Shard(dim=0),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
BACKWARD PASS
*c10d_functional.all_gather_into_tensor: 1
**aten.mm.default
shape: [torch.Size([4, 5]), torch.Size([5, 32])]
sharding: [(Replicate(),), (Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.mm.default
shape: [torch.Size([5, 4]), torch.Size([4, 32])]
sharding: [(Replicate(),), (Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.sum.dim_IntList
shape: [torch.Size([4, 5])]
sharding: [(Replicate(),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
06/16/2025 05:53:25 PM Sequence Parallel iter 1 completed
06/16/2025 05:53:25 PM Sequence Parallel iter 2 completed
06/16/2025 05:53:25 PM Sequence Parallel iter 3 completed
06/16/2025 05:53:25 PM Sequence Parallel iter 4 completed
06/16/2025 05:53:25 PM Sequence Parallel iter 5 completed
06/16/2025 05:53:25 PM Sequence Parallel iter 6 completed
06/16/2025 05:53:25 PM Sequence Parallel iter 7 completed
06/16/2025 05:53:25 PM Sequence Parallel iter 8 completed
06/16/2025 05:53:25 PM Sequence Parallel iter 9 completed
06/16/2025 05:53:25 PM Sequence Parallel training completed!
[rank0]:[W616 17:53:25.948217933 ProcessGroupNCCL.cpp:1516] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
Sorry, I meant on non-CUDA devices: does this API work if you use MPS or CPU?
torch.accelerator works for CUDA as well as non-CUDA GPUs and accelerators. CommDebugMode is also a core PyTorch feature, so it should work on all devices; if not, that would be a bug.
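For illustration, device selection in these examples follows roughly this pattern (a minimal sketch assuming PyTorch 2.6+ where torch.accelerator is available; the exact wiring in the PR may differ):

```python
import torch

# Pick whichever accelerator is present (cuda, xpu, mps, ...); fall back to CPU.
if torch.accelerator.is_available():
    device_type = torch.accelerator.current_accelerator().type
else:
    device_type = "cpu"

inp = torch.rand(4, 10, device=device_type)
```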
@msaroufim, if there are no more questions, could this be merged?
Can you please attach logs confirming this works on CPU?
@msaroufim, I have only used GPUs for this kind of work. The accelerator API does not cover CPUs, and I do not know whether TP and SP are supported on CPUs or, if so, which distributed backend they would use. The original code would not work on CPUs either, as far as I can tell. In summary, these two examples were not written for CPUs, and adding CPU support would be a significant change, if it is possible at all.
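For reference, a common pattern for picking the distributed backend by device type would be something like the sketch below, though I have not verified that these TP/SP examples actually run on gloo/CPU:

```python
import torch
import torch.distributed as dist

# Hypothetical backend selection; gloo-on-CPU support for these examples is untested.
backend = "nccl" if torch.cuda.is_available() else "gloo"
dist.init_process_group(backend=backend)
```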
I guess I'm confused by the goal of this PR overall:
- Why merge a device-agnostic API if the code is only expected to work on a single device? If that's the case, then keeping cuda is actually clearer.
- I'm not sure why CommDebugMode is introduced and why it should be default behavior.
@msaroufim, great questions. Let me address them:
- There are non-CUDA GPUs and accelerators (e.g. XPU, MTIA, HPU). torch.accelerator is a write-once, run-anywhere interface, so the model code runs on any supported accelerator without surgery.
- Since these are distributed examples, a way to see what is happening in the communication layer should be very informative. It could also be gated behind an input option, if that is preferable; see the sketch below.
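A hypothetical sketch of what gating it behind a flag could look like (the --debug-comm option name and the surrounding tp_model/inp/optimizer objects are assumptions, not part of this PR):

```python
import argparse
import contextlib

from torch.distributed.tensor.debug import CommDebugMode

parser = argparse.ArgumentParser()
parser.add_argument("--debug-comm", action="store_true",
                    help="print CommDebugMode tracing for the first iteration")
args = parser.parse_args()

# Only pay the tracing cost when explicitly requested.
ctx = CommDebugMode() if args.debug_comm else contextlib.nullcontext()
with ctx:
    output = tp_model(inp)  # tp_model, inp, optimizer as in the example
    output.sum().backward()
    optimizer.step()

if args.debug_comm:
    print(ctx.generate_comm_debug_tracing_table(noise_level=1))
```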
OK, this makes sense. We can merge this if you fix the breakage in the CI job and make the CommDebugMode output optional in a follow-up PR.
Looks like the failing CUDA test below (Run Distributed Examples / test (pull_request)) runs with a relatively old version of PyTorch (torch==2.4.0.dev20240605+cu11). The upcoming release is 2.8.
Changing cuda to accelerator and adding CommDebugMode to tensor_parallel_example.py, sequence_parallel_example.py, and log_utils.py.